#include <os/config.h>

#if SWAP_ENABLE

#include <os/kernelif.h>
#include <os/swap.h>
#include <os/debug.h>
#include <os/timer.h>
#include <os/safety.h>
#include <os/diskio.h>
#include <os/diskman.h>
#include <arch/swapswitch.h>
#include <arch/page.h>
#include <lib/string.h>
#include <lib/stddef.h>
#include <lib/unistd.h>
#include <lib/errno.h>
#include <sys/fcntl.h>
#include <sys/swap.h>
#include <sys/ioctl.h>
#include <arch/memory.h>

#define DEBUG_SWAP

// In-memory copy of the swap header; allocated in SwapInit() and filled in by SwapLoad().
swap_header_t *swap = NULL;
// In-memory allocation bitmap, one bit per swap page (bit set = page in use); allocated by SwapLoad().
uint8_t *bitmap = NULL;
int fd; // global swap file fd; NOTE(review): SwapUpdate/SwapIn/SwapOut declare a local `fd` that shadows this one

// load swap file data from disk
#ifdef SWAP_FILE_MODE
// Load the swap header and allocation bitmap from the swap file SWAP_FILE_NAME.
// When the file does not exist it is created, zero-filled up to SWAP_FILE_SIZE,
// and a fresh header plus an empty bitmap are written at its start.
// Expects `swap` to already point at a swap_header_t (SwapInit allocates it).
// Returns 0 on success, -1 on any failure.
static int SwapLoad()
{
    KPrint("[swap] ready to load swap file\n");

    if (KFileAccess(SWAP_FILE_NAME, O_RDONLY) < 0) // file no exist
    {
        fd = KFileOpen(SWAP_FILE_NAME, O_CREATE | O_RDWR); // create file
        if (fd < 0)
        {
            KPrint("create swap file %s failed!\n", SWAP_FILE_NAME);
            return -1;
        }

        uint8_t buff[SECTOR_SIZE] = {0};
        // Round up to whole sectors. The previous "+ 1" always produced a file
        // larger than SWAP_FILE_SIZE, which the size check on the next load rejects.
        uint64_t count = ((uint64_t)SWAP_FILE_SIZE + SECTOR_SIZE - 1) / SECTOR_SIZE;

        // fill file to size
        for (uint64_t i = 0; i < count; i++)
        {
            if (KFileWrite(fd, buff, SECTOR_SIZE) < 0)
            {
                KPrint("swap file %s fill failed!\n", SWAP_FILE_NAME);
                KFileClose(fd);
                return -1;
            }
        }

        // init swap file header
        swap->total_size = SWAP_FILE_SIZE;
        swap->free_off = PAGE_SIZE;
        swap->free_size = (SWAP_FILE_SIZE - swap->free_off) & PAGEALIGN_MASK;
        swap->bitmap_size = ((swap->free_size / PAGE_SIZE) - 1) / 8;
        swap->used_size = 0;
        swap->last_alloc = 0;
        swap->magic = SWAP_MAGIC;

        // create bitmap and init; must exist before SwapDump() walks it
        bitmap = KMemAlloc(swap->bitmap_size);
        if (!bitmap)
        {
            KPrint("swap file %s alloc bitmap failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }
        memset(bitmap, 0, swap->bitmap_size);

        SwapDump();

        if (KFileLSeek(fd, 0, SEEK_SET) < 0)
        {
            KPrint("swap file %s seek failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        // write header
        if (KFileWrite(fd, swap, sizeof(swap_header_t)) < 0)
        {
            KPrint("when create swap file %s,write swap header failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        // write bitmap: the whole bitmap, not sizeof(pointer) bytes
        if (KFileWrite(fd, bitmap, swap->bitmap_size) < 0)
        {
            KPrint("swap file %s init bitmap area failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        KFileClose(fd);
    }
    else
    {
        // file exist
        fd = KFileOpen(SWAP_FILE_NAME, O_RDONLY);
        if (fd < 0)
        {
            // nothing to close: the open itself failed
            KPrint("swap file %s open failed!\n", SWAP_FILE_NAME);
            return -1;
        }

        // read header
        if (KFileRead(fd, swap, sizeof(swap_header_t)) != sizeof(swap_header_t))
        {
            KPrint("swap file %s read swap header failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        // get swap file status
        status_t status;
        uint32_t size;
        if (KFileStatus(SWAP_FILE_NAME, &status) < 0)
        {
            KPrint("swap file %s get file status failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }
        size = status.st_size;

        // swap file header error check
        if (swap->magic != SWAP_MAGIC || size != SWAP_FILE_SIZE || swap->total_size != size || !swap->bitmap_size)
        {
            KPrint("swap file %s read swap header errror! file maybe had boken!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        // the bitmap buffer has to exist before it can be read into
        bitmap = KMemAlloc(swap->bitmap_size);
        if (!bitmap)
        {
            KPrint("swap file %s alloc bitmap failed!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }

        // read bitmap from swap file (it follows the header directly)
        if (KFileRead(fd, bitmap, swap->bitmap_size) != swap->bitmap_size)
        {
            KPrint("swap file %s read bitmap area error!\n file maybe had boken!\n", SWAP_FILE_NAME);
            KFileClose(fd);
            return -1;
        }
        KFileClose(fd);
        KPrint("swap file %s load from disk finished! file size=%d\n", SWAP_FILE_NAME, SWAP_FILE_SIZE);
    }
    return 0;
}
#else
static int SwapLoad()
{
    KPrint("[swap] ready to load swap config\n");

    char buff[SECTOR_SIZE];
    memset(buff, 0, SECTOR_SIZE);

    int solt = DiskSoltFindByPath(SWAP_BLOCK_DEVICE);
    if (solt < 0)
    {
        KPrint("[swap] swap %s device no found\n", SWAP_BLOCK_DEVICE);
        return -1;
    }
    if (DiskOpen(solt) < 0)
    {
        KPrint("[swap] swap %s device open err\n", SWAP_BLOCK_DEVICE);
        return -1;
    }
    if (DiskRead(solt, buff, 0, 1) < 0)
    {
        KPrint("[swap] swap %s device read err\n", SWAP_BLOCK_DEVICE);
        return -1;
    }
    // init global varities point
    swap_header_t *_swap = (uint8_t *)buff;
    uint8_t *_bitmap = (uint8_t *)buff + sizeof(swap_header_t);

    int create = 0;

    if (_swap->magic != SWAP_MAGIC || _swap->total_size == 0 || _swap->bitmap_size == 0)
    {
        create = 1;
    }

    // if (create == 1)
    if (1)
    {
        KPrint("[swap] create swap system\n");
        int size = 0;
        if (DiskIoCtl(solt, DISKIO_GETSIZE, &size) < 0)
        {
            KPrint("[swap] swap %s device get size err\n", SWAP_BLOCK_DEVICE);
            return -1;
        }
#ifdef DEBUG_SWAP
        KPrint("[swap] swap %s device size: %d\n", SWAP_BLOCK_DEVICE, size);
#endif

        _swap->total_size = size & PAGEALIGN_MASK;
        _swap->free_off = PAGE_SIZE;
        _swap->free_size = PAGE_ALIGN(size - _swap->free_off);
        _swap->bitmap_size = ((_swap->free_size / PAGE_SIZE) - 1) / 8;
        _swap->used_size = 0;
        _swap->last_alloc = 0;
        _swap->magic = SWAP_MAGIC;

        memset(_bitmap, 0, _swap->bitmap_size);

        if (DiskWrite(solt, buff, 0, 1) < 0)
        {
            KPrint("[swap] write swap info err\n");
            return -1;
        }
    }
    DiskClose(solt);

    memcpy(swap, _swap, sizeof(swap_header_t));
    bitmap = KMemAlloc(_swap->bitmap_size);
    if (!bitmap)
        return -1;
    memcpy(bitmap, _bitmap, _swap->bitmap_size);
#ifdef DEBUG_SWAP
    SwapDump();
#endif
    return 0;
}
#endif

// Mark swap page `bit` as free in the bitmap and move one page of accounting
// from used_size back to free_size. Does not check whether the bit was set.
// Always returns 0.
static inline int SwapClearBitmap(uint64_t bit)
{
    // the mask must address the bit inside its byte, hence bit % 8
    bitmap[bit / 8] &= (uint8_t)~(1u << (bit % 8)); // clear disk page alloc flags
    swap->free_size += PAGE_SIZE;
    swap->used_size -= PAGE_SIZE;
    return 0;
}

// Mark swap page `bit` as allocated in the bitmap and move one page of
// accounting from free_size to used_size. Does not check whether the bit was
// already set. Always returns 0 (SwapAllocBitmap tests the result for < 0).
static inline int SwapSetBitmap(uint64_t bit)
{
    // the mask must address the bit inside its byte, hence bit % 8
    bitmap[bit / 8] |= (uint8_t)(1u << (bit % 8)); // set disk page alloc flags
    swap->free_size -= PAGE_SIZE;
    swap->used_size += PAGE_SIZE;
    return 0;
}

#ifdef SWAP_FILE_MODE
static int SwapUpdate()
{
    int fd = KFileOpen(SWAP_FILE_NAME, O_RDWR);
    if (fd < 0)
        return -1;

    if (KFileLSeek(fd, 0, SEEK_SET) < 0)
    {
        KPrint("swap file %s seek failed when update!\n", SWAP_FILE_NAME);
        KFileClose(fd);
        return -1;
    }
    if (KFileWrite(fd, swap, sizeof(swap_header_t)) != sizeof(swap_header_t)) // update swap header
    {
        KPrint("swap file %s update header failed!\n", SWAP_FILE_NAME);
        KFileClose(fd);
        return -1;
    }
    if (KFileWrite(fd, bitmap, swap->bitmap_size) != swap->bitmap_size) // update bitmap
    {
        KPrint("swap file %s update bitmap failed!\n", SWAP_FILE_NAME);
        KFileClose(fd);
        return -1;
    }
    KFileClose(fd);
    return 0;
}
#else
// Persist the in-memory swap header and bitmap to sector 0 of the swap device.
// The sector layout is: swap_header_t, then the bitmap bytes.
// Returns 0 on success, -1 on any failure.
static int SwapUpdate()
{
    int solt = DiskSoltFindByPath(SWAP_BLOCK_DEVICE);
    if (solt < 0)
    {
        KPrint("[swap] device no found\n");
        return -1;
    }
    if (swap->magic != SWAP_MAGIC)
    {
        KPrint("[swap] magic no valid!\n");
        return -1;
    }
    // header + bitmap are staged in one sector-sized stack buffer; refuse to
    // copy a bitmap that would run past its end
    if (swap->bitmap_size > SECTOR_SIZE - sizeof(swap_header_t))
    {
        KPrint("[swap] bitmap too large for header sector\n");
        return -1;
    }
    if (DiskOpen(solt) < 0)
    {
        KPrint("[swap] device open err\n");
        return -1;
    }
    uint8_t buff[SECTOR_SIZE];
    memset(buff, 0, SECTOR_SIZE);

    memcpy(buff, swap, sizeof(swap_header_t));
    memcpy(buff + sizeof(swap_header_t), bitmap, swap->bitmap_size);

    if (DiskWrite(solt, buff, 0, 1) < 0)
    {
        KPrint("[swap] write device err\n");
        DiskClose(solt); // do not leave the device open on failure
        return -1;
    }
    DiskClose(solt);
    return 0;
}
#endif

// Initialise the swap subsystem: allocate and zero the global swap header,
// then load (or create) the swap file/device state through SwapLoad().
// Returns 0 on success, -1 if the header cannot be allocated or loading fails.
int SwapInit()
{
    // init page info
    // SwapSwitchInit();

    // alloc swap header
    swap = KMemAlloc(sizeof(swap_header_t));
    if (!swap)
        return -1;
    // clear swap header
    memset(swap, 0, sizeof(swap_header_t));
    // load swap file/device
    if (SwapLoad() < 0)
    {
        KPrint("load swap file %s failed!\n", SWAP_FILE_NAME);
        return -1;
    }
    KPrint("[swap] init ok\n");
    return 0; // the function previously fell off the end without a value
}

#ifdef SWAP_FILE_MODE
// swap in a page from disk to mem
// Swap in the page backing virtual address `addr` from the swap file.
// The page must have a present PDE and a non-present PTE; the PTE then holds
// the page's offset inside the swap data area. A new physical page is
// allocated, mapped at `addr`, filled from the file, and the swap slot is
// released. Returns 0 on success, -1 on failure; on failure the PTE is left
// as it was and no physical page stays allocated.
int SwapIn(uint64_t addr)
{
    pde_t *pde = GetPdeVptr(addr);
    pte_t *pte = GetPteVptr(addr);

    if (!(*pde & PAGE_PRESENT) || (*pte & PAGE_PRESENT))
    {
        KPrint("[swap] swap in: virtual page %x present\n", addr);
        return -1;
    }

    uint64_t diskoff = (*pte >> 12) & PAGEBASE_MASK; // get page offset in disk file
    uint64_t diskidx = diskoff / PAGE_SIZE;
    if (diskoff >= SWAP_FILE_SIZE)
        return -1;

    if (!SwapTestBitmap(diskidx))
        return -1;

    // free_size shrinks with every allocated page, so the bound of the data
    // area is free_size + used_size, not free_size alone
    uint64_t area_size = swap->free_size + swap->used_size;
    if (diskoff > area_size || diskoff + PAGE_SIZE > area_size) // check
    {
        KPrint("swap file %s no enough space!\n", SWAP_FILE_NAME);
        return -1;
    }

    uint64_t off = swap->free_off + diskoff;

    KPrint("swap in: read disk off %d\n", off);

    // allocate only after all validation passed, so early exits leak nothing
    uint32_t pypage = AllocKernelPage(1);
    if (!pypage)
    {
        KPrint("[swap] alloc pypage err\n");
        return -1;
    }

    int fd = KFileOpen(SWAP_FILE_NAME, O_RDONLY);
    if (fd < 0)
    {
        KPrint("[swap] swap in: open file err\n");
        FreePage(pypage);
        return -1;
    }

    if (KFileLSeek(fd, off, SEEK_SET) < 0) // set to assign file offset
    {
        KPrint("swap file %s seek failed!\n", SWAP_FILE_NAME);
        KFileClose(fd);
        FreePage(pypage);
        return -1;
    }

    // update pte to pypage and set page present bit; keep the old entry so a
    // failed read can restore the swapped-out state
    pte_t old_pte = *pte;
    *pte &= ~PAGEBASE_MASK;
    *pte |= pypage & PAGEBASE_MASK;
    *pte |= PAGE_PRESENT;

    // must map address to new pypage
    if (KFileRead(fd, PTYPE(addr), PAGE_SIZE) != PAGE_SIZE) // read a disk page to mem
    {
        KPrint("swap file %s swap in page from off %d to mem %x failed!\n", SWAP_FILE_NAME, off, pypage);
        KFileClose(fd);
        *pte = old_pte;
        FreePage(pypage);
        return -1;
    }

    KFileClose(fd);

    SwapFreeBitmap(diskidx); // release the swap slot now that the page is back in memory
    return 0;
}
#else
// swap in a page from disk to mem
// Swap in the page backing virtual address `addr` from the swap device.
// The page must have a present PDE and a non-present PTE; the PTE then holds
// the page's offset inside the swap data area. A new physical page is
// allocated, mapped at `addr`, filled from the device, and the swap slot is
// released. Returns 0 on success, -1 on failure; on failure the PTE is left
// as it was, the device is closed and no physical page stays allocated.
int SwapIn(uint64_t addr)
{
    pde_t *pde = GetPdeVptr(addr);
    pte_t *pte = GetPteVptr(addr);

    if (!(*pde & PAGE_PRESENT) || (*pte & PAGE_PRESENT))
    {
        KPrint("[swap] swap in: virtual page %x present\n", addr);
        return -1;
    }

    uint64_t diskoff = (*pte >> 12) & PAGEBASE_MASK; // get page offset in disk file
    uint64_t diskidx = diskoff / PAGE_SIZE;
    if (diskoff >= swap->total_size)
        return -1;

    if (!SwapTestBitmap(diskidx))
        return -1;

    // free_size shrinks with every allocated page, so the bound of the data
    // area is free_size + used_size, not free_size alone
    uint64_t area_size = swap->free_size + swap->used_size;
    if (diskoff > area_size || diskoff + PAGE_SIZE > area_size) // check
    {
        KPrint("swap device %s no enough space!\n", SWAP_BLOCK_DEVICE);
        return -1;
    }

    uint64_t off = swap->free_off + diskoff;

    KPrint("swap in: read disk off %d off %d\n", diskoff, off);

    int solt = DiskSoltFindByPath(SWAP_BLOCK_DEVICE);
    if (solt < 0)
    {
        KPrint("[swap] found device err\n");
        return -1;
    }

    // allocate only after all validation passed, so early exits leak nothing
    uint32_t pypage = AllocKernelPage(1);
    if (!pypage)
    {
        KPrint("[swap] alloc pypage err\n");
        return -1;
    }
    KPrint("[swap in] alloc pypage at %x\n", pypage);

    if (DiskOpen(solt) < 0)
    {
        KPrint("[swap] open device err\n");
        FreePage(pypage);
        return -1;
    }

    // update pte to pypage and set page present bit; keep the old entry so a
    // failed read can restore the swapped-out state
    pte_t old_pte = *pte;
    wmb();
    *pte &= ~PAGEBASE_MASK;
    *pte |= pypage & PAGEBASE_MASK;
    *pte |= PAGE_PRESENT;
    wmb();

    // must map address to new pypage
    if (DiskRead(solt, addr, off / SECTOR_SIZE, PAGE_SIZE / SECTOR_SIZE) < 0)
    {
        KPrint("[swap] read data err\n");
        DiskClose(solt);
        wmb();
        *pte = old_pte;
        wmb();
        FreePage(pypage);
        return -1;
    }

    DiskClose(solt);

    SwapFreeBitmap(diskidx); // release the swap slot now that the page is back in memory
    return 0;
}
#endif

#ifdef SWAP_FILE_MODE
// swap out a page from memory to disk
// Swap out the present page mapped at virtual address `addr` to the swap file.
// A free swap slot is allocated, the page contents are written to it, the
// physical page is freed and the PTE is rewritten to hold the slot's offset
// with the present bit cleared.
// Returns 0 on success, -1 on failure; on failure the swap slot is released
// again and the mapping is left untouched.
int SwapOut(uint64_t addr)
{
    pde_t *pde = GetPdeVptr(addr);
    pte_t *pte = GetPteVptr(addr);

    if (!(*pde & PAGE_PRESENT) || !(*pte & PAGE_PRESENT))
    {
        KPrint("[swap] request a no present page\n");
        return -1;
    }

    uint64_t page = *pte & PAGEBASE_MASK;

    // keep the result signed: stored in a uint64_t, the -1 failure value
    // could never be detected by "< 0"
    int slot = SwapAllocBitmap(); // alloc free disk page
    if (slot < 0)
    {
        KPrint("swap file %s no free page!\n", SWAP_FILE_NAME);
        return -1;
    }

    uint64_t diskidx = (uint64_t)slot;
    uint64_t diskoff = diskidx * PAGE_SIZE;
    uint64_t off = swap->free_off + diskoff;

    // free_size shrinks with every allocated page, so the bound of the data
    // area is free_size + used_size, not free_size alone
    uint64_t area_size = swap->free_size + swap->used_size;
    if (diskoff > area_size || diskoff + PAGE_SIZE > area_size) // check
    {
        KPrint("swap file %s no enough space!\n", SWAP_FILE_NAME);
        SwapFreeBitmap(diskidx);
        return -1;
    }

    int fd = KFileOpen(SWAP_FILE_NAME, O_WRONLY);
    if (fd < 0)
    {
        KPrint("[swap] open file failed\n");
        SwapFreeBitmap(diskidx);
        return -1;
    }

    KPrint("swap out: write off %d\n", off);

    if (KFileLSeek(fd, off, SEEK_SET) < 0)
    {
        KPrint("swap file %s seek to %d failed!\n", SWAP_FILE_NAME, off);
        KFileClose(fd);
        SwapFreeBitmap(diskidx);
        return -1;
    }

    if (KFileWrite(fd, PTYPE(addr), PAGE_SIZE) != PAGE_SIZE) // write pypage to disk
    {
        KPrint("swap file %s write page from mem %x to off %d failed!\n", SWAP_FILE_NAME, page, off);
        KFileClose(fd);
        SwapFreeBitmap(diskidx);
        return -1;
    }
    KFileClose(fd);

    // free pypage
    FreePage(page);
    // update pte to diskoff and clear page present bit
    *pte &= ~PAGE_PRESENT;
    *pte &= ~PAGEBASE_MASK;
    *pte |= (diskoff << 12) & PAGEALIGN_MASK;
    return 0;
}
#else
// swap out a page from memory to disk
// Swap out the present page mapped at virtual address `addr` to the swap device.
// A free swap slot is allocated, the page contents are written to it, the
// physical page is freed and the PTE is rewritten to hold the slot's offset
// with the present bit cleared.
// Returns 0 on success, -1 on failure; on failure the swap slot is released
// again, the device is closed and the mapping is left untouched.
int SwapOut(uint64_t addr)
{
    pde_t *pde = GetPdeVptr(addr);
    pte_t *pte = GetPteVptr(addr);

    if (!(*pde & PAGE_PRESENT) || !(*pte & PAGE_PRESENT))
    {
        KPrint("[swap] request a no present page\n");
        return -1;
    }

    uint64_t page = *pte & PAGEBASE_MASK;

    // keep the result signed: stored in a uint64_t, the -1 failure value
    // could never be detected by "< 0"
    int slot = SwapAllocBitmap(); // alloc free disk page
    if (slot < 0)
    {
        KPrint("swap device %s no free page!\n", SWAP_BLOCK_DEVICE);
        return -1;
    }

    uint64_t diskidx = (uint64_t)slot;
    uint64_t diskoff = diskidx * PAGE_SIZE;
    uint64_t off = swap->free_off + diskoff;

    KPrint("[swap] out: diskoff %x off %x\n", diskoff, off);

    // free_size shrinks with every allocated page, so the bound of the data
    // area is free_size + used_size, not free_size alone
    uint64_t area_size = swap->free_size + swap->used_size;
    if (diskoff > area_size || diskoff + PAGE_SIZE > area_size) // check
    {
        KPrint("swap device %s no enough space!\n", SWAP_BLOCK_DEVICE);
        SwapFreeBitmap(diskidx);
        return -1;
    }

    int solt = DiskSoltFindByPath(SWAP_BLOCK_DEVICE);
    if (solt < 0)
    {
        KPrint("[swap] found device err\n");
        SwapFreeBitmap(diskidx);
        return -1;
    }

    if (DiskOpen(solt) < 0)
    {
        KPrint("[swap] open device err\n");
        SwapFreeBitmap(diskidx);
        return -1;
    }

    KPrint("swap out: write off %d\n", off);
    if (DiskWrite(solt, addr, off / SECTOR_SIZE, PAGE_SIZE / SECTOR_SIZE) < 0)
    {
        KPrint("[swap] write disk err\n");
        DiskClose(solt);
        SwapFreeBitmap(diskidx);
        return -1;
    }

    DiskClose(solt);

    // free pypage
    FreePage(page);

    wmb();
    // update pte to diskoff and clear page present bit
    *pte &= ~PAGE_PRESENT;
    *pte &= ~PAGEBASE_MASK;
    *pte |= (diskoff << 12) & PAGEALIGN_MASK;
    wmb();
    return 0;
}
#endif

// Ask SwapSwitchPage() for a page to evict and swap it out.
// Returns -1 when no page is offered (address 0xffffffff), otherwise the
// result of SwapOut() for that page.
int SwapTrySwapOut()
{
    address_t address = SwapSwitchPage(); // pick a page to swap out
    if (address == 0xffffffff)
        return -1;
    return SwapOut(address); // propagate the result instead of falling off the end
}

// Swap in the page backing `address`.
// Returns the result of SwapIn(): 0 on success, -1 on failure.
int SwapTrySwapIn(address_t address)
{
    // only support user page swap
    return SwapIn(address); // propagate the result instead of falling off the end
}

// alloc bitmap free bit
// Allocate one free swap page from the bitmap.
// The search starts at the slot after the last allocation (or at slot 0 when
// last_alloc is 0) and wraps around so that every slot is examined exactly
// once. On success the bit is set, last_alloc is updated, the header/bitmap
// are persisted through SwapUpdate(), and the slot index is returned.
// Returns -1 when the bitmap is empty or no free slot exists.
int SwapAllocBitmap()
{
    uint64_t total_bits = swap->bitmap_size * 8;
    if (total_bits == 0)
        return -1;

    // a last_alloc of 0 is treated as "nothing allocated yet"
    uint64_t idx = swap->last_alloc ? swap->last_alloc + 1 : 0;
    if (idx >= total_bits)
        idx = 0; // wrap instead of reading past the end of the bitmap

    for (uint64_t scanned = 0; scanned < total_bits; scanned++)
    {
        // the mask must address the bit inside its byte, hence idx % 8
        if (!(bitmap[idx / 8] & (1u << (idx % 8))))
        {
            if (SwapSetBitmap(idx) < 0)
                return -1;
            // record the allocation before persisting so the stored header is current
            swap->last_alloc = idx;
            SwapUpdate();
            return (int)idx;
        }

        if (++idx >= total_bits)
            idx = 0;
    }

    return -1;
}

// free bitmap use bit
int SwapFreeBitmap(uint64_t bit)
{
    if (bitmap[bit / 8] & (1 << bit)) // test whether bit is set
    {
        SwapClearBitmap(bit);
    }

    SwapUpdate();
}

// Test whether swap page `bit` is marked allocated in the bitmap.
// Returns 1 if the bit is set, 0 otherwise.
int SwapTestBitmap(uint64_t bit)
{
    // select the bit inside its byte; the old "1 << bit" mask never matched for bit >= 8
    return (bitmap[bit / 8] >> (bit % 8)) & 1;
}

// Print the swap header fields and a hex dump of the allocation bitmap
// (five bytes per line). Prints only the title when the header is not
// allocated, and skips the bitmap bytes when the bitmap is not allocated.
void SwapDump()
{
    #ifdef SWAP_FILE_MODE
    KPrint("Swap file %s dump\n", SWAP_FILE_NAME);
    #else 
    KPrint("swap device %s dump\n",SWAP_BLOCK_DEVICE);
    #endif

    if (!swap)
        return; // nothing to report before SwapInit() allocated the header

    KPrint("file size %d bitmap size %d\n", (uint32_t)swap->total_size, (uint32_t)swap->bitmap_size);
    KPrint("alloc status:\n");
    KPrint("total size=%d used size=%d free size=%d\n", (uint32_t)swap->total_size, (uint32_t)swap->used_size, (uint32_t)swap->free_size);
    KPrint("---bitmap dump----\n");

    if (!bitmap)
        return; // bitmap_size may already be set while the buffer is not allocated yet

    for (uint32_t i = 0; i < swap->bitmap_size; i++)
    {
        KPrint("0x%x ", bitmap[i]);
        if ((i + 1) % 5 == 0)
            KPrint("\n");
    }
}

// swap mem info get
// Copy the swap size counters (total, free, used) into the caller-supplied
// swap_info structure.
// Returns 0 on success, -EINVAL when SafetyCheckRange() rejects the buffer or
// when the swap header has not been allocated.
int SysSwapMem(swap_info_t *swap_info)
{
    if (SafetyCheckRange(swap_info, sizeof(swap_info_t)) < 0)
    {
        return -EINVAL;
    }
    // NOTE(review): -EINVAL is reused here because it is the only errno this
    // file uses; switch to a more specific code if one is available.
    if (!swap)
    {
        return -EINVAL;
    }
    swap_info->total_size = swap->total_size;
    swap_info->free_size = swap->free_size;
    swap_info->used_size = swap->used_size;
    return 0;
}

#endif